import pandas as pd
import numpy as np
from numba import njit
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LogisticRegression as Mnl
from sklearn.linear_model import LinearRegression as Ols
import itertools
import os
#os.chdir('/Users/Nicholas/Dropbox/bbm')

class BuildData:
    """Build the cleaned CSV inputs used by the estimation.

    The class reads two Stata files and writes four CSV files, all under
    ``./data/processed`` (paths are relative to the current working
    directory):

    Inputs:
        - state.dta
        - contract_final_ym.dta

    Outputs:
        - data_state.csv
            cleaned up version of state data.
        - data_contracts_clean.csv
            cleaned up version of contracts data.
        - data_contracts_state.csv
            merged contracts and state data.
        - data_mnl.csv
            data that goes into estimating the multinomial logit.

    Overall this class is just a collection of static methods;
    instantiating it runs all of them in order.
    """

    def __init__(self):
        """Run all of the data construction steps on startup.

        The order matters: the last two steps read the CSV files written
        by the first two.
        """
        self.construct_data_state()
        self.construct_data_contracts()
        self.merge_contracts_state()
        self.build_mnl_data()

    @staticmethod
    def construct_data_state():
        """Clean up the state data and save it as data_state.csv.

        Reads ``state.dta``, spreads the per-spec columns (``n_unemp``,
        ``utilization``, ``n_available_spec``) into one column per rig
        spec, keeps a single row per month and writes the result rounded
        to three decimals.
        """
        df = pd.read_stata("./data/processed/state.dta")
        df.set_index('ym', inplace=True)

        # Assigning a Series to a column aligns on the 'ym' index, so
        # every row of a month receives that month's value for `spec`.
        # This relies on 'ym' being unique within each rig_spec
        # (presumably one row per month and spec -- verify against the
        # Stata file).
        per_spec_columns = (('unemp_', 'n_unemp'),
                            ('util_', 'utilization'),
                            ('n_available_', 'n_available_spec'))
        for spec in ['low', 'mid', 'high']:
            is_spec = df['rig_spec'] == spec + '-spec'
            for prefix, source in per_spec_columns:
                df[prefix + spec] = df.loc[is_spec, source]

        df.rename(columns={'n_available_spec_smooth_low': 'n_l',
                           'n_available_spec_smooth_mid': 'n_m',
                           'n_available_spec_smooth_high': 'n_h',
                           'gas_price_smooth': 'g'},
                  inplace=True)

        # The per-spec values now live in their own columns, so only the
        # high-spec rows are kept to avoid repeating each month.
        df = df[df['rig_spec'] == 'high-spec']

        df = df[[
            'gas_price',
            'g',
            'n_l',
            'n_m',
            'n_h',
            'unemp_low',
            'unemp_mid',
            'unemp_high',
            'util_low',
            'util_mid',
            'util_high',
            'n_available_low',
            'n_available_mid',
            'n_available_high'
        ]]

        df = df.reset_index()
        df = df.round(3)
        df.to_csv('./data/processed/data_state.csv')

    @staticmethod
    def construct_data_contracts():
        """Clean up the contracts data and save it as
        data_contracts_clean.csv.

        Keeps contracts fixed on or after 2000-01-01, rescales the
        monetary columns, bins ``duration`` into the contract length
        ``tau`` in {2, 3, 4} and computes the discounted total dayrate
        over ``tau`` periods with discount factor ``beta``.
        """
        beta = 0.99

        # Read in and name the columns.  The explicit copy keeps the
        # later in-place edits from acting on a slice of the raw frame.
        df = pd.read_stata("./data/processed/contract_final_ym.dta")
        df = df[['name_ihs',
                 'rig_spec',
                 'operator_ihs',
                 'contractor_ihs',
                 'contractstart',
                 'fixturedate',
                 'bid',
                 'dayrate',
                 'waterd',
                 'mri',
                 'duration',
                 'reneg',
                 'type',
                 'max_drill']].copy()

        df.rename(columns={'bid': 'value'}, inplace=True)
        df['value'] = df['value'] / (1000000 * 5.83 * 30)

        # Clean up dates
        df['fixturedate'] = pd.to_datetime(df['fixturedate'],
                                           format='%Y-%m-%d')
        df['contractstart'] = pd.to_datetime(df['contractstart'],
                                             format='%Y-%m-%d')
        first_date = pd.to_datetime('2000-1-1', format='%Y-%m-%d')
        df = df[df['fixturedate'] >= first_date].copy()
        # 'ym' is the first day of the month of the fixture date.
        df['ym'] = df['fixturedate'].apply(lambda dt: dt.replace(day=1))

        # Clean up the covariates
        df['mri'] = df['mri'] / 1000
        df['dayrate'] = df['dayrate'] / 1000000
        df['tau'] = pd.cut(df['duration'],
                           bins=[0, 75, 105, 10000],
                           labels=[2, 3, 4])

        # Compute dayrate total (i.e. go from price per day to total
        # contract price): dayrate * (1 + beta + ... + beta**(tau - 1)).
        df['dayrate_total'] = np.nan
        for tau in (2, 3, 4):
            discount_sum = sum(beta ** k for k in range(tau))
            df.loc[df['tau'] == tau, 'dayrate_total'] \
                = df['dayrate'] * discount_sum

        # Durations outside the bins above have no tau and are dropped.
        df = df.dropna(subset=['tau'])

        df.to_csv('./data/processed/data_contracts_clean.csv')

    @staticmethod
    def merge_contracts_state():
        """Merge the contract data with the state data and save it
        as data_contracts_state.csv .

        The merge is a left join on 'ym', so every contract row is kept.
        """
        data_contracts_clean = pd.read_csv(
            './data/processed/data_contracts_clean.csv')
        data_state = pd.read_csv('./data/processed/data_state.csv')

        data_contracts_state = data_contracts_clean.merge(data_state,
                                                          on='ym',
                                                          how='left')

        data_contracts_state.to_csv('./data/processed/data_contracts_state.csv')

    @staticmethod
    def build_mnl_data():
        """Set up the contracts and state data in a form where it can
        be used to fit the multinomial logit model.

        The output stacks the observed contracts (columns 'ym', 'tau',
        'rig_spec') with one ``tau = 0`` row per unmatched rig: for every
        month and spec, the month is repeated ``unemp_<spec>`` times.

        Months whose count is missing contribute no rows.  Counts are
        cast to int because they may come back from the CSV as floats,
        which ``Series.repeat`` does not accept.
        """
        data_contracts_clean = pd.read_csv(
            'data/processed/data_contracts_clean.csv')
        data_state = pd.read_csv('data/processed/data_state.csv')

        frames = [data_contracts_clean[['ym', 'tau', 'rig_spec']]]

        for spec in ['low', 'mid', 'high']:
            counts = data_state['unemp_' + spec]
            # Filter months and counts with the same mask so that they
            # stay aligned when some counts are missing.
            observed = counts.notna()
            months = data_state.loc[observed, 'ym'].repeat(
                counts[observed].astype(int).to_numpy()
            )
            frames.append(pd.DataFrame({'ym': months.to_numpy(),
                                        'tau': 0,
                                        'rig_spec': spec + '-spec'}))

        data_mnl = pd.concat(frames,
                             sort=True).reset_index(drop=True)
        data_mnl.to_csv('./data/processed/data_mnl.csv')


class PolicyFunctions:
    """ Estimate the policy functions for the mean dayrate and the
    probability of matching with a particular contract length.

    Steps to use:
        1. Initialize the class with a polynomial order
        2. Access the dictionary of policy functions with self.match .

    self.match is keyed by rig spec ('low', 'mid', 'high'); each entry
    maps 'tau', 'mean_mri', 'mean_value' and 'mean_dayrate' to a
    prediction function.
    """

    def __init__(self, polynomial_order=2):
        """ Initialize the policy functions in several steps.

        Args:
            polynomial_order: the order of the polynomial of states that
            we want to fit. Default = 2.

        Returns:
            The main attribute to access (which is computed upon
            initialization) is self.match .
        """

        # 1. Read in the data
        (contracts_poly,
         mnl_poly,
         poly_names,
         poly) = self.setup_data(polynomial_order)

        # 2. Get the policy function dictionaries.  Both estimators
        # return an (output, reg) pair; only the functions are needed.
        output_mnl, _ = self.estimate_mnl(mnl_poly=mnl_poly,
                                          poly_names=poly_names,
                                          poly=poly)

        output_ols, _ = self.estimate_ols(contracts_poly=contracts_poly,
                                          poly_names=poly_names,
                                          poly=poly)

        # 3. Put them all into one dictionary
        self.match = dict()

        for spec in ['low', 'mid', 'high']:
            # estimate_ols stores (predictor, reg) pairs; keep only the
            # predictor so that every entry of self.match is callable.
            ols_functions = {name: entry[0]
                             for name, entry in output_ols[spec].items()}
            self.match[spec] = {**output_mnl[spec], **ols_functions}

    @staticmethod
    def setup_data(polynomial_order):
        """ Read the estimation data and attach polynomial state terms.

        Reads states.csv and mnl.csv from ./data_py/processed and
        data_contracts_clean.csv from ./data/processed, expands the
        states [g, n_l, n_m, n_h] into polynomial features and
        left-merges them onto the contracts and mnl data by 'month'.
        The merged frames are also written to ./data/processed for
        inspection.

        Args:
            polynomial_order: degree passed to PolynomialFeatures.

        Returns:
            contracts_poly: contracts merged with the polynomial states.
            mnl_poly: mnl data merged with the polynomial states.
            poly_names: list of the polynomial column names.
            poly: the fitted PolynomialFeatures instance.
        """
        # Get data
        state = pd.read_csv('./data_py/processed/states.csv')
        mnl = pd.read_csv('./data_py/processed/mnl.csv')
        contracts = pd.read_csv('./data/processed/data_contracts_clean.csv')
        contracts = contracts.rename(columns={'ym': 'month'})
        contracts['rig_spec'] = contracts['rig_spec'].replace(
            {'low-spec': 'low', 'mid-spec': 'mid', 'high-spec': 'high'})
        # NOTE(review): presumably states.csv / mnl.csv store 'month' as
        # 'YYYY-MM' strings too -- confirm, otherwise the merges below
        # match nothing.
        contracts['month'] = pd.to_datetime(contracts['month']).dt.strftime('%Y-%m')

        # Setup state names
        state_names = ['g', 'n_l', 'n_m', 'n_h']

        # Do polynomial data
        poly = PolynomialFeatures(polynomial_order,
                                  include_bias=False)

        # Get data in the polynomial form
        poly_fit = poly.fit_transform(state[state_names])
        # get_feature_names was removed from recent scikit-learn
        # releases in favour of get_feature_names_out.  The result is
        # converted to a plain list of str because callers concatenate
        # it with other lists of column names.
        if hasattr(poly, 'get_feature_names_out'):
            poly_names = [
                str(name)
                for name in poly.get_feature_names_out(state_names)
            ]
        else:
            poly_names = list(poly.get_feature_names(state_names))

        state_poly = pd.DataFrame(poly_fit,
                                  columns=poly_names,
                                  index=state['month'])

        # 'month' is the index name of state_poly and a column of the
        # left-hand frames.
        contracts_poly = contracts.merge(state_poly,
                                         on='month',
                                         how='left')

        mnl_poly = mnl.merge(state_poly,
                             on='month',
                             how='left')

        # Save to view
        contracts_poly.to_csv('./data/processed/contracts_poly.csv')
        mnl_poly.to_csv('./data/processed/mnl_poly.csv')

        return contracts_poly, mnl_poly, poly_names, poly

    @staticmethod
    def estimate_mnl(mnl_poly, poly_names, poly):
        """ Estimates a multinomial logit model.

        Args:
            mnl_poly: data in the form for multinomial logit regression
                that includes polynomial state names.
            poly_names: the names of the polynomial state columns in
                the dataframe.
            poly: the polynomial class

        Returns:
            output: a dictionary of dictionaries of functions that fit
                the multinomial logit to the data. This is in the
                form of:
                    output[rig spec]['tau']
                with rig spec in ['low', 'mid', 'high'].
            reg: dictionary of the fitted models, keyed by
                'tau <rig spec>'.

        """
        reg = dict()
        df = dict()
        output = {'low': dict(),
                  'mid': dict(),
                  'high': dict()}

        def reg_fn_1(i, j):
            """ This is a closure wrapper to store the values i and
            j.

            Args:
                i: ['tau']
                j: rig type in ['low','mid','high']

            Returns:
                out: a function that stores the i and j and builds
                    a function that will fit the multinomial logit
                    model to state data.

            """
            key = i + ' ' + j

            # Build things for the multinomial logit
            mnl_spec = mnl_poly[(mnl_poly['rig_spec'] == j)]
            df[key] = mnl_spec.dropna(subset=[i] + poly_names)

            mnl = Mnl(random_state=0,
                      multi_class='multinomial',
                      solver='newton-cg',
                      max_iter=10000)

            # fn_1 saves the following function:
            reg[key] = mnl.fit(df[key][poly_names],
                               df[key][i])

            def fn_1(s, vector=False):
                """ Transform a state to probabilities of matching
                each type of contract length (note that the prob
                of matching no contract can be computed as
                1 - sum(probs of matching)

                Args:
                    s: array of states with columns [g,n_l,n_m,n_h]
                    vector: if true, s is used as a 2-D array as is;
                        otherwise s is reshaped into a single row.

                Returns:
                    out: array of probability of matching each type
                    of contract length in [2,3,4].

                """
                if vector:
                    fit = poly.fit_transform(s)
                else:
                    fit = poly.fit_transform(np.array(s).reshape(1, -1))
                # Columns 1:4 skip the first class (presumably the
                # tau = 0 "no match" outcome -- confirm against
                # reg[key].classes_).
                # NOTE(review): the trailing [0] keeps only the first
                # row, so with vector=True the probabilities of all
                # other rows are discarded.  Left unchanged because
                # returning every row would change the output shape.
                out = reg[key].predict_proba(fit)[:, 1:4][0]
                return out

            return fn_1

        for i, j in itertools.product(['tau'],
                                      ['low',
                                       'mid',
                                       'high']):
            output[j][i] = reg_fn_1(i, j)

        return output, reg

    @staticmethod
    def estimate_ols(contracts_poly, poly_names, poly):
        """ Estimates a linear model.

        Args:
            contracts_poly: data in the form for a linear regression
                that includes polynomial state names.
            poly_names: the names of the polynomial state columns in
                the dataframe.
            poly: the polynomial class

        Returns:
            output: a dictionary of dictionaries in the form of:

                    output[rig spec]['mean_' + type]

                where type is in: [mri,value,dayrate] and rig spec in
                ['low', 'mid', 'high'].  Each entry is a
                (predictor, reg) pair: the prediction function and the
                shared dictionary of fitted models.
            reg: dictionary of the fitted models, keyed by
                '<type> <rig spec>'.

        """
        reg = dict()
        output = {'low': dict(),
                  'mid': dict(),
                  'high': dict()}

        def reg_fn_1(i, j):
            """ This is a closure wrapper to store the values i and
            j.

            Args:
                i: [mri,value,dayrate]
                j: rig type in ['low','mid','high']

            Returns:
                out: a (function, reg) pair, where the function stores
                    the i and j and applies the fitted linear
                    regression model to state data.

            """
            key = i + ' ' + j
            regressors = poly_names + ['tau']

            mask = (contracts_poly['rig_spec'] == j)
            df = contracts_poly[mask].dropna(subset=[i] + regressors)

            # reg is the object that fn_1 saves
            reg[key] = Ols().fit(df[regressors],
                                 df[i])

            def fn_1(s, tau):
                """ Function that does the fitting.

                Args:
                    s: array of states with columns [g,n_l,n_m,n_h];
                        either a single state (1-D) or one state per
                        row (2-D).
                    tau: contract length, used for every row of s.

                Returns:
                    out: array with one prediction per state, clipped
                        below at 0.

                """
                # A single state becomes a one-row matrix so that both
                # input shapes share the same code path.
                s = np.atleast_2d(s)
                fit = poly.fit_transform(s)
                x = np.append(
                    fit,
                    np.repeat([tau], len(s)).reshape(-1, 1),
                    axis=1
                )

                prediction = reg[key].predict(x)

                # Ensure prediction is >= 0
                prediction = prediction.clip(min=0)
                return prediction

            # NOTE(review): the fitted-model dictionary is returned
            # next to the predictor, unlike estimate_mnl; kept as is so
            # that existing callers see the same structure.
            return fn_1, reg

        for i, j in itertools.product(['mri',
                                       'value',
                                       'dayrate'],
                                      ['low',
                                       'mid',
                                       'high']):
            output[j]['mean_' + i] = reg_fn_1(i, j)

        return output, reg


def build_ols_predictor(coefs):
    """Build an njit-compiled evaluator of a quadratic-in-state linear model.

    Args:
        coefs: indexable with 16 entries laid out as
            [intercept, g, n_l, n_m, n_h, g*g, g*n_l, g*n_m, g*n_h,
             n_l*n_l, n_l*n_m, n_l*n_h, n_m*n_m, n_m*n_h, n_h*n_h, tau].

    Returns:
        ols_predictor(state, tau): the linear combination of the terms
        above, where state is indexable as [g, n_l, n_m, n_h].
    """

    @njit
    def ols_predictor(state, tau):
        g, n_l, n_m, n_h = state[0], state[1], state[2], state[3]

        # Terms are accumulated in coefficient order.
        total = coefs[0]

        # First-order terms.
        total += coefs[1] * g
        total += coefs[2] * n_l
        total += coefs[3] * n_m
        total += coefs[4] * n_h

        # Squares and cross products.
        total += coefs[5] * g * g
        total += coefs[6] * g * n_l
        total += coefs[7] * g * n_m
        total += coefs[8] * g * n_h
        total += coefs[9] * n_l * n_l
        total += coefs[10] * n_l * n_m
        total += coefs[11] * n_l * n_h
        total += coefs[12] * n_m * n_m
        total += coefs[13] * n_m * n_h
        total += coefs[14] * n_h * n_h

        # Contract length.
        total += coefs[15] * tau
        return total

    return ols_predictor


def build_mnl_predictor(coefs):
    """Build an njit-compiled multinomial-logit probability evaluator.

    Args:
        coefs: 2-D array with (at least) 4 rows and 15 columns.  Each
            row holds the coefficients of one outcome, laid out as
            [intercept, g, n_l, n_m, n_h, g*g, g*n_l, g*n_m, g*n_h,
             n_l*n_l, n_l*n_m, n_l*n_h, n_m*n_m, n_m*n_h, n_h*n_h].
            Presumably the rows are the "no match" outcome followed by
            the contract lengths -- confirm against the code that
            builds coefs.

    Returns:
        mnl_predictor(state): array of 4 probabilities (softmax of the
        4 linear indices), where state is indexable as
        [g, n_l, n_m, n_h].
    """

    @njit
    def mnl_predictor(state):
        g, n_l, n_m, n_h = state[0], state[1], state[2], state[3]

        x = np.empty(4)
        for i in range(4):
            index = coefs[i, 0]
            index += coefs[i, 1] * g
            index += coefs[i, 2] * n_l
            index += coefs[i, 3] * n_m
            index += coefs[i, 4] * n_h
            index += coefs[i, 5] * g * g
            index += coefs[i, 6] * g * n_l
            index += coefs[i, 7] * g * n_m
            index += coefs[i, 8] * g * n_h
            index += coefs[i, 9] * n_l * n_l
            index += coefs[i, 10] * n_l * n_m
            index += coefs[i, 11] * n_l * n_h
            index += coefs[i, 12] * n_m * n_m
            index += coefs[i, 13] * n_m * n_h
            index += coefs[i, 14] * n_h * n_h
            x[i] = index

        # Softmax is invariant to subtracting a constant from every
        # index.  Shifting by the maximum keeps every exponent <= 0, so
        # exp cannot overflow and produce inf / inf = nan.
        weights = np.exp(x - np.max(x))

        return weights / np.sum(weights)

    return mnl_predictor


def build_prob_extension_predictor(coefs):
    """Build an njit-compiled logistic probability evaluator.

    Args:
        coefs: indexable with 16 entries laid out as
            [intercept, g, n_l, n_m, n_h, g*g, g*n_l, g*n_m, g*n_h,
             n_l*n_l, n_l*n_m, n_l*n_h, n_m*n_m, n_m*n_h, n_h*n_h, tau].

    Returns:
        prob_extension_predictor(state, tau): the logistic function of
        the linear index built from the terms above, where state is
        indexable as [g, n_l, n_m, n_h].
    """

    @njit
    def prob_extension_predictor(state, tau):
        g, n_l, n_m, n_h = state[0], state[1], state[2], state[3]

        x = coefs[0]
        x += coefs[1] * g
        x += coefs[2] * n_l
        x += coefs[3] * n_m
        x += coefs[4] * n_h
        x += coefs[5] * g * g
        x += coefs[6] * g * n_l
        x += coefs[7] * g * n_m
        x += coefs[8] * g * n_h
        x += coefs[9] * n_l * n_l
        x += coefs[10] * n_l * n_m
        x += coefs[11] * n_l * n_h
        x += coefs[12] * n_m * n_m
        x += coefs[13] * n_m * n_h
        x += coefs[14] * n_h * n_h
        x += coefs[15] * tau

        # Equivalent to exp(x) / (1 + exp(x)), but that form evaluates
        # to inf / inf = nan once exp(x) overflows.  Here a large
        # positive x gives 1 and a large negative x gives 0.
        return 1.0 / (1.0 + np.exp(-x))

    return prob_extension_predictor


def build_prob_extension_predictor_with_mri(coefs):
    """Build an njit-compiled logistic probability evaluator with mri.

    Args:
        coefs: indexable with 17 entries laid out as
            [intercept, g, n_l, n_m, n_h, g*g, g*n_l, g*n_m, g*n_h,
             n_l*n_l, n_l*n_m, n_l*n_h, n_m*n_m, n_m*n_h, n_h*n_h,
             tau, mri].

    Returns:
        prob_extension_predictor_with_mri(state, tau, mri): the logistic
        function of the linear index built from the terms above, where
        state is indexable as [g, n_l, n_m, n_h].
    """

    @njit
    def prob_extension_predictor_with_mri(state, tau, mri):
        g, n_l, n_m, n_h = state[0], state[1], state[2], state[3]

        x = coefs[0]
        x += coefs[1] * g
        x += coefs[2] * n_l
        x += coefs[3] * n_m
        x += coefs[4] * n_h
        x += coefs[5] * g * g
        x += coefs[6] * g * n_l
        x += coefs[7] * g * n_m
        x += coefs[8] * g * n_h
        x += coefs[9] * n_l * n_l
        x += coefs[10] * n_l * n_m
        x += coefs[11] * n_l * n_h
        x += coefs[12] * n_m * n_m
        x += coefs[13] * n_m * n_h
        x += coefs[14] * n_h * n_h
        x += coefs[15] * tau
        x += coefs[16] * mri

        # Equivalent to exp(x) / (1 + exp(x)), but that form evaluates
        # to inf / inf = nan once exp(x) overflows.  Here a large
        # positive x gives 1 and a large negative x gives 0.
        return 1.0 / (1.0 + np.exp(-x))

    return prob_extension_predictor_with_mri


def build_price_extension_predictor(coefs):
    """Build an njit-compiled evaluator of a quadratic-in-state linear model.

    Args:
        coefs: indexable with 16 entries laid out as
            [intercept, g, n_l, n_m, n_h, g*g, g*n_l, g*n_m, g*n_h,
             n_l*n_l, n_l*n_m, n_l*n_h, n_m*n_m, n_m*n_h, n_h*n_h, tau].

    Returns:
        price_extension_predictor(state, tau): the linear combination of
        the terms above (no transformation is applied), where state is
        indexable as [g, n_l, n_m, n_h].
    """

    @njit
    def price_extension_predictor(state, tau):
        g, n_l, n_m, n_h = state[0], state[1], state[2], state[3]

        # Intercept, then the level terms in coefficient order.
        price = coefs[0]
        price += coefs[1] * g
        price += coefs[2] * n_l
        price += coefs[3] * n_m
        price += coefs[4] * n_h

        # Second-order terms.
        price += coefs[5] * g * g
        price += coefs[6] * g * n_l
        price += coefs[7] * g * n_m
        price += coefs[8] * g * n_h
        price += coefs[9] * n_l * n_l
        price += coefs[10] * n_l * n_m
        price += coefs[11] * n_l * n_h
        price += coefs[12] * n_m * n_m
        price += coefs[13] * n_m * n_h
        price += coefs[14] * n_h * n_h

        # Contract-length term.
        price += coefs[15] * tau

        return price

    return price_extension_predictor